package org.springframework.security.oauth2.server.authorization.web;

import com.luo.sc.oidc.authserver.enums.AuthenticaionResultCodeEnum;
import com.luo.sc.oidc.authserver.handler.login.UniLoginRespJsonDto;
import com.luo.sc.oidc.authserver.handler.token.OAuth2TokenEndpointAuthenticationResponseHandler;
import com.luo.sc.oidc.authserver.utils.HttpContextUtils;
import com.luo.sc.oidc.authserver.utils.JsonUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
import org.springframework.util.CollectionUtils;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.time.temporal.ChronoUnit;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Token endpoint authentication response handler - default implementation.
 * <p>
 * Reproduces the response behaviour of {@link OAuth2TokenEndpointFilter}: on success the access token
 * response is serialized into the HTTP response; on failure an OAuth 2.0 error response is written with
 * HTTP 400 (RFC 6749, section 5.2).
 *
 * @author luohq
 * @version 1.0.0
 * @date 2022-03-16 09:50
 * @see OAuth2TokenEndpointFilter
 */
@Slf4j
public class DefaultOAuth2TokenEndpointAuthenticationResponseHandlerImpl implements OAuth2TokenEndpointAuthenticationResponseHandler {

    /** Writes the successful access token response body. */
    private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter = new OAuth2AccessTokenResponseHttpMessageConverter();
    /** Writes the OAuth 2.0 error response body. */
    private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();

    /**
     * Writes the access token response for a successful token request.
     * <p>
     * Keeps the logic of the original {@code OAuth2TokenEndpointFilter.sendAccessTokenResponse}.
     *
     * @param request        the current request (not used)
     * @param response       the response the token response is written to
     * @param authentication the authentication result; must be an {@link OAuth2AccessTokenAuthenticationToken}
     * @throws IOException if writing the response fails
     */
    @Override
    public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, Authentication authentication) throws IOException, ServletException {
        OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = (OAuth2AccessTokenAuthenticationToken) authentication;

        OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
        OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
        Map<String, Object> additionalParameters = accessTokenAuthentication.getAdditionalParameters();

        OAuth2AccessTokenResponse.Builder builder =
                OAuth2AccessTokenResponse.withToken(accessToken.getTokenValue())
                        .tokenType(accessToken.getTokenType())
                        .scopes(accessToken.getScopes());
        // expires_in is only emitted when both timestamps are known
        if (accessToken.getIssuedAt() != null && accessToken.getExpiresAt() != null) {
            builder.expiresIn(ChronoUnit.SECONDS.between(accessToken.getIssuedAt(), accessToken.getExpiresAt()));
        }
        if (refreshToken != null) {
            builder.refreshToken(refreshToken.getTokenValue());
        }
        if (!CollectionUtils.isEmpty(additionalParameters)) {
            builder.additionalParameters(additionalParameters);
        }
        OAuth2AccessTokenResponse accessTokenResponse = builder.build();
        ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
        this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
    }

    /**
     * Writes an OAuth 2.0 error response with HTTP 400 (Bad Request).
     * <p>
     * Compatible with the original {@code OAuth2TokenEndpointFilter.sendErrorResponse}; protocol reference:
     * https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
     *
     * @param request   the current request (not used)
     * @param response  the response the error is written to
     * @param exception the authentication failure; an {@link OAuth2AuthenticationException} contributes its own
     *                  {@link OAuth2Error}, any other type is reported as {@code invalid_request}
     * @throws IOException if writing the response fails
     */
    @Override
    public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException, ServletException {
        OAuth2Error error = resolveError(exception);
        ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
        httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
        this.errorHttpResponseConverter.write(error, null, httpResponse);
    }

    /**
     * Extracts the {@link OAuth2Error} to report for the given failure.
     * <p>
     * The handler contract accepts any {@link AuthenticationException}; an unchecked cast would raise a
     * ClassCastException (an unstructured 500) for non-OAuth2 exceptions, so those are mapped to a generic
     * {@code invalid_request} error instead. The exception message is deliberately not echoed to the client.
     *
     * @param exception the authentication failure
     * @return the error to serialize into the response body
     */
    private static OAuth2Error resolveError(AuthenticationException exception) {
        if (exception instanceof OAuth2AuthenticationException) {
            return ((OAuth2AuthenticationException) exception).getError();
        }
        log.debug("Oauth2 Token endpoint received non-OAuth2 authentication failure", exception);
        return new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST);
    }

    // Alternative failure handler (disabled): responds with HTTP 200 + a JSON body instead of an
    // RFC 6749 error response.
    //@Override
    //public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException exception) throws IOException, ServletException {
    //    /** For the original logic see OAuth2TokenEndpointFilter.sendErrorResponse; protocol https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 */
    //    //Respond with HTTP 200 + JSON (compatible with grant_type=password mode)
    //    OAuth2Error oAuth2Error = ((OAuth2AuthenticationException) exception).getError();
    //    log.debug("Oauth2 Token endpoint authentication failure - {}", oAuth2Error);
    //    //Build the authentication-failure response body
    //    String respJson = JsonUtils.toJson(UniLoginRespJsonDto.builder()
    //            .code(AuthenticaionResultCodeEnum.FAILURE.getCode())
    //            .msg(this.convertErrorMsg(oAuth2Error))
    //            .build());
    //    log.debug("Oauth2 Token endpoint authentication failure resp: {}", respJson);
    //    HttpContextUtils.responseJson(respJson, response);
    //}
    //
    //
    ///**
    // * Convert the OAuth2 error into a display message
    // *
    // * @param oAuth2Error
    // * @return
    // */
    //private String convertErrorMsg(OAuth2Error oAuth2Error) {
    //    return Stream.of(oAuth2Error.getErrorCode(), oAuth2Error.getDescription())
    //            .filter(error -> !AuthenticaionResultCodeEnum.FAILURE.getCode().equals(error))
    //            .collect(Collectors.joining(" - "));
    //}

}
